{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Document Embedding with Amazon SageMaker Object2Vec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. [Introduction](#Introduction)\n",
    "2. [Background](#Background)\n",
    "   1. [Embedding documents using Object2Vec](#Embedding-documents-using-Object2Vec)\n",
    "3. [Download and preprocess Wikipedia data](#Download-and-preprocess-Wikipedia-data)\n",
    "   1. [Install and load dependencies](#Install-and-load-dependencies)\n",
    "   2. [Build vocabulary and tokenize datasets](#Build-vocabulary-and-tokenize-datasets)\n",
    "   3. [Upload preprocessed data to S3](#Upload-preprocessed-data-to-S3)\n",
    "4. [Define SageMaker session, Object2Vec image, S3 input and output paths](#Define-SageMaker-session,-Object2Vec-image,-S3-input-and-output-paths)\n",
    "5. [Train and deploy doc2vec](#Train-and-deploy-doc2vec)\n",
    "   1. [Learning performance boost with new features](#Learning-performance-boost-with-new-features)\n",
    "   2. [Training speedup with sparse gradient update](#Training-speedup-with-sparse-gradient-update)\n",
    "6. [Apply learned embeddings to document retrieval task](#Apply-learned-embeddings-to-document-retrieval-task)\n",
    "   1. [Comparison with the StarSpace algorithm](#Comparison-with-the-StarSpace-algorithm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we introduce four new features to Object2Vec, a general-purpose neural embedding algorithm: negative sampling, sparse gradient update, weight-sharing, and comparator operator customization. The new features together broaden the applicability of Object2Vec, improve its training speed and accuracy, and provide users with greater flexibility. See [Introduction to the Amazon SageMaker Object2Vec](https://aws.amazon.com/blogs/machine-learning/introduction-to-amazon-sagemaker-object2vec/) if you aren’t already familiar with Object2Vec.\n",
    "\n",
    "We demonstrate how these new features extend the applicability of Object2Vec to a new document embedding use case. A customer has a large collection of documents. Instead of storing these documents in their raw format or as sparse bag-of-words vectors, she would like to embed all documents in a common low-dimensional space. The goal is to preserve the semantic distances between the documents while making training in various downstream tasks more efficient."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Background"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Object2Vec is a highly customizable, multi-purpose algorithm that learns embeddings of pairs of objects. The embeddings are learned such that they preserve the pairwise similarities of the objects in the original space.\n",
    "\n",
    "- Similarity is user-defined: users need to provide the algorithm with pairs of objects that they define as similar (1) or dissimilar (0). Alternatively, users can define similarity on a continuous scale by providing a real-valued similarity score.\n",
    "\n",
    "- The learned embeddings can be used to efficiently compute nearest neighbors of objects, as well as to visualize natural clusters of related objects in the embedding space. In addition, the embeddings can also be used as features of the corresponding objects in downstream supervised tasks such as classification or regression."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Embedding documents using Object2Vec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We demonstrate how, with the new features, Object2Vec can be used to embed a large collection of documents into vectors in the same latent space.\n",
    "\n",
    "Similar to the widely used Word2Vec algorithm for word embedding, a natural approach to document embedding is to preprocess documents as (sentence, context) pairs, where the sentence and its matching context come from the same document. The matching context is the entire document with the given sentence removed. The idea is to embed both sentence and context into a low-dimensional space such that their mutual similarity is maximized, since they belong to the same document and therefore should be semantically related. The learned encoder for the context can then be used to encode new documents into the same embedding space.\n",
    "\n",
    "To train the encoders for sentences and documents, we also need negative (sentence, context) pairs so that the model can learn to discriminate between semantically similar and dissimilar pairs. It is easy to generate such negatives by pairing sentences with documents that they do not belong to. Naturally occurring data contains many more negative pairs than positive ones. We therefore typically use random sampling to balance positive and negative pairs in the training data. The figure below shows how the positive and negative pairs are generated from unlabeled data for the purpose of learning embeddings for documents (and sentences)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"doc_embedding_illustration.png\" width=\"800\">"
   ]
  },
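  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following cell is a minimal sketch of the pair-generation scheme described above. It builds positive and negative (sentence, context) pairs from a tiny in-memory corpus. The toy documents and the helper `make_pairs` are hypothetical and only illustrate the idea on raw strings. The preprocessing actually used for the Wikipedia data is defined later in this notebook and works on token ids."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "def make_pairs(docs, num_negatives=1, seed=0):\n",
    "    # docs: list of documents, each a list of sentence strings (hypothetical toy format).\n",
    "    # Returns (sentence, context, label) triples: label 1 pairs a sentence with the rest of\n",
    "    # its own document, label 0 pairs it with randomly chosen documents it does not belong to.\n",
    "    rng = random.Random(seed)\n",
    "    pairs = []\n",
    "    for d, sentences in enumerate(docs):\n",
    "        i = rng.randrange(len(sentences))\n",
    "        sentence = sentences[i]\n",
    "        # positive pair: the same document with the sampled sentence removed\n",
    "        context = sentences[:i] + sentences[i + 1:]\n",
    "        pairs.append((sentence, ' '.join(context), 1))\n",
    "        # negative pairs: other documents\n",
    "        others = [j for j in range(len(docs)) if j != d]\n",
    "        for j in rng.sample(others, min(num_negatives, len(others))):\n",
    "            pairs.append((sentence, ' '.join(docs[j]), 0))\n",
    "    return pairs\n",
    "\n",
    "toy_docs = [\n",
    "    ['cats purr.', 'cats nap a lot.', 'kittens are young cats.'],\n",
    "    ['rust is a language.', 'it has a borrow checker.'],\n",
    "    ['paris is in france.', 'the seine flows through paris.'],\n",
    "]\n",
    "for sentence, context, label in make_pairs(toy_docs, num_negatives=1):\n",
    "    print(label, '|', sentence, '|', context)"
   ]
  },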
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We show how Object2Vec with the new *negative sampling feature* can be applied to the document embedding use-case. In addition, we show how the other new features, namely, *weight-sharing*, *customization of comparator operator*, and *sparse gradient update*, together enhance the algorithm's performance and user-experience in and beyond this use-case. Sections [Learning performance boost with new features](#Learning-performance-boost-with-new-features) and [Training speedup with sparse gradient update](#Training-speedup-with-sparse-gradient-update) in this notebook provide a detailed introduction to the new features."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download and preprocess Wikipedia data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please note the following license terms regarding acknowledgment, copyright, and availability, quoted from the [data source license page](https://github.com/facebookresearch/StarSpace/blob/master/LICENSE.md).\n",
    "\n",
    "> Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the \"Software\"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "\n",
    "DATANAME=\"wikipedia\"\n",
    "DATADIR=\"/tmp/wiki\"\n",
    "\n",
    "mkdir -p \"${DATADIR}\"\n",
    "\n",
    "if [ ! -f \"${DATADIR}/${DATANAME}_train250k.txt\" ]\n",
    "then\n",
    "    echo \"Downloading wikipedia data\"\n",
    "    wget --quiet -c \"https://dl.fbaipublicfiles.com/starspace/wikipedia_train250k.tgz\" -O \"${DATADIR}/${DATANAME}_train.tar.gz\"\n",
    "    tar -xzvf \"${DATADIR}/${DATANAME}_train.tar.gz\" -C \"${DATADIR}\"\n",
    "    wget --quiet -c \"https://dl.fbaipublicfiles.com/starspace/wikipedia_devtst.tgz\" -O \"${DATADIR}/${DATANAME}_test.tar.gz\"\n",
    "    tar -xzvf \"${DATADIR}/${DATANAME}_test.tar.gz\" -C \"${DATADIR}\"\n",
    "fi\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datadir = '/tmp/wiki'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls /tmp/wiki"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install and load dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install jsonlines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# note: please run on python 3 kernel\n",
    "\n",
    "import os\n",
    "import random\n",
    "\n",
    "import math\n",
    "import scipy\n",
    "import numpy as np\n",
    "\n",
    "import re\n",
    "import string\n",
    "import json, jsonlines\n",
    "\n",
    "from collections import defaultdict\n",
    "from collections import Counter\n",
    "\n",
    "from itertools import chain, islice\n",
    "\n",
    "from nltk.tokenize import TreebankWordTokenizer\n",
    "from sklearn.preprocessing import normalize\n",
    "\n",
    "## sagemaker api\n",
    "import sagemaker, boto3\n",
    "from sagemaker.session import s3_input\n",
    "from sagemaker.predictor import json_serializer, json_deserializer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build vocabulary and tokenize datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BOS_SYMBOL = \"<s>\"\n",
    "EOS_SYMBOL = \"</s>\"\n",
    "UNK_SYMBOL = \"<unk>\"\n",
    "PAD_SYMBOL = \"<pad>\"\n",
    "PAD_ID = 0\n",
    "TOKEN_SEPARATOR = \" \"\n",
    "VOCAB_SYMBOLS = [PAD_SYMBOL, UNK_SYMBOL, BOS_SYMBOL, EOS_SYMBOL]\n",
    "\n",
    "\n",
    "##### utility functions for preprocessing\n",
    "def get_article_iter_from_file(fname):\n",
    "    with open(fname) as f:\n",
    "        for article in f:\n",
    "            yield article\n",
    "\n",
    "def get_article_iter_from_channel(channel, datadir='/tmp/wiki'):\n",
    "    if channel == 'train':\n",
    "        fname = os.path.join(datadir, 'wikipedia_train250k.txt')\n",
    "        return get_article_iter_from_file(fname)\n",
    "    else:\n",
    "        iterlist = []\n",
    "        suffix_list = ['train250k.txt', 'test10k.txt', 'dev10k.txt', 'test_basedocs.txt']\n",
    "        for suffix in suffix_list:\n",
    "            fname = os.path.join(datadir, 'wikipedia_'+suffix)\n",
    "            iterlist.append(get_article_iter_from_file(fname))\n",
    "        return chain.from_iterable(iterlist)\n",
    "\n",
    "\n",
    "def readlines_from_article(article):\n",
    "    return article.strip().split('\\t')\n",
    "\n",
    "\n",
    "def sentence_to_integers(sentence, word_dict, trim_size=None):\n",
    "    \"\"\"\n",
    "    Converts a string of tokens to a list of integers\n",
    "    \"\"\"\n",
    "    if not trim_size:\n",
    "        return [word_dict[token] if token in word_dict else 0 for token in get_tokens_from_sentence(sentence)]\n",
    "    else:\n",
    "        integer_list = []\n",
    "        for token in get_tokens_from_sentence(sentence):\n",
    "            if len(integer_list) < trim_size:\n",
    "                if token in word_dict:\n",
    "                    integer_list.append(word_dict[token])\n",
    "                else:\n",
    "                    integer_list.append(0)\n",
    "            else:\n",
    "                break\n",
    "        return integer_list\n",
    "\n",
    "\n",
    "def get_tokens_from_sentence(sent):\n",
    "    \"\"\"\n",
    "    Yields tokens from input string.\n",
    "\n",
    "    :param sent: Input string.\n",
    "    :return: Iterator over tokens.\n",
    "    \"\"\"\n",
    "    for token in sent.split():\n",
    "        if len(token) > 0:\n",
    "            yield normalize_token(token)\n",
    "\n",
    "\n",
    "def get_tokens_from_article(article):\n",
    "    iterlist = []\n",
    "    for sent in readlines_from_article(article):\n",
    "        iterlist.append(get_tokens_from_sentence(sent))\n",
    "    return chain.from_iterable(iterlist)\n",
    "\n",
    "\n",
    "def normalize_token(token):\n",
    "    token = token.lower()\n",
    "    if all(s.isdigit() or s in string.punctuation for s in token):\n",
    "        tok = list(token)\n",
    "        for i in range(len(tok)):\n",
    "            if tok[i].isdigit():\n",
    "                tok[i] = '0'\n",
    "        token = \"\".join(tok)\n",
    "    return token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# function to build vocabulary\n",
    "\n",
    "def build_vocab(channel, num_words=50000, min_count=1, use_reserved_symbols=True, sort=True):\n",
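    "    \"\"\"\n",
    "    Creates a vocabulary mapping from words to ids. Only words with at least min_count occurrences\n",
    "    are considered, and the num_words most frequent of them are kept. If sort is True, the kept words\n",
    "    are then ordered alphabetically; otherwise they are ordered by decreasing frequency. If\n",
    "    use_reserved_symbols is True, the special symbols (PAD, UNK, BOS, EOS) are prepended, so that the\n",
    "    padding symbol (PAD) gets id 0.\n",
    "\n",
    "    :param channel: Data channel to read articles from ('train' reads only the training file).\n",
    "    :param num_words: Maximum number of words in the vocabulary.\n",
    "    :param min_count: Minimum occurrences of words to be included in the vocabulary.\n",
    "    :param use_reserved_symbols: Whether to prepend the special symbols.\n",
    "    :param sort: Whether to sort the retained words alphabetically.\n",
    "    :return: word-to-id mapping.\n",
    "    \"\"\"\n",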
    "    vocab_symbols_set = set(VOCAB_SYMBOLS)\n",
    "    raw_vocab = Counter()\n",
    "    for article in get_article_iter_from_channel(channel):\n",
    "        article_wise_vocab_list = list()\n",
    "        for token in get_tokens_from_article(article):\n",
    "            if token not in vocab_symbols_set:\n",
    "                article_wise_vocab_list.append(token)\n",
    "        raw_vocab.update(article_wise_vocab_list)\n",
    "\n",
    "    print(\"Initial vocabulary: {} types\".format(len(raw_vocab)))\n",
    "\n",
    "    # For words with the same count, they will be ordered reverse alphabetically.\n",
    "    # Not an issue since we only care for consistency\n",
    "    pruned_vocab = sorted(((c, w) for w, c in raw_vocab.items() if c >= min_count), reverse=True)\n",
    "    print(\"Pruned vocabulary: {} types (min frequency {})\".format(len(pruned_vocab), min_count))\n",
    "\n",
    "    # truncate the vocabulary to fit size num_words (only includes the most frequent ones)\n",
    "    vocab = islice((w for c, w in pruned_vocab), num_words)\n",
    "\n",
    "    if sort:\n",
    "        # sort the vocabulary alphabetically\n",
    "        vocab = sorted(vocab)\n",
    "    if use_reserved_symbols:\n",
    "        vocab = chain(VOCAB_SYMBOLS, vocab)\n",
    "\n",
    "    word_to_id = {word: idx for idx, word in enumerate(vocab)}\n",
    "\n",
    "    print(\"Final vocabulary: {} types\".format(len(word_to_id)))\n",
    "\n",
    "    if use_reserved_symbols:\n",
    "        # Important: pad symbol becomes index 0\n",
    "        assert word_to_id[PAD_SYMBOL] == PAD_ID\n",
    "    \n",
    "    return word_to_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build vocab dictionary\n",
    "\n",
    "def build_vocabulary_file(vocab_fname, channel, num_words=50000, min_count=1, \n",
    "                          use_reserved_symbols=True, sort=True, force=False):\n",
    "    if not os.path.exists(vocab_fname) or force:\n",
    "        w_dict = build_vocab(channel, num_words=num_words, min_count=min_count,\n",
    "                             use_reserved_symbols=use_reserved_symbols, sort=sort)\n",
    "        with open(vocab_fname, \"w\") as write_file:\n",
    "            json.dump(w_dict, write_file)\n",
    "\n",
    "channel = 'train'\n",
    "min_count = 5\n",
    "vocab_fname = os.path.join(datadir, 'wiki-vocab-{}250k-mincount-{}.json'.format(channel, min_count))\n",
    "\n",
    "build_vocabulary_file(vocab_fname, channel, num_words=500000, min_count=min_count, force=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Loading vocab file {} ...\".format(vocab_fname))\n",
    "\n",
    "with open(vocab_fname) as f:\n",
    "    w_dict = json.load(f)\n",
    "    print(\"The vocabulary size is {}\".format(len(w_dict.keys())))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Functions to build training data \n",
    "# Tokenize wiki articles to (sentence, document) pairs\n",
    "def generate_sent_article_pairs_from_single_article(article, word_dict):\n",
    "    sent_list = readlines_from_article(article)\n",
    "    idx = random.randint(0, len(sent_list) - 1)\n",
    "    # context = every sentence of the article except the sampled one; slicing with idx + 1\n",
    "    # (no modulo) avoids re-adding the whole article when idx is the last sentence\n",
    "    wrapper_text_list = sent_list[:idx] + sent_list[idx + 1:]\n",
    "    wrapper_tokens = []\n",
    "    for sent1 in wrapper_text_list:\n",
    "        wrapper_tokens += sentence_to_integers(sent1, word_dict)\n",
    "    sent_tokens = sentence_to_integers(sent_list[idx], word_dict)\n",
    "    yield {'in0':sent_tokens, 'in1':wrapper_tokens, 'label':1}\n",
    "\n",
    "\n",
    "def generate_sent_article_pairs_from_single_file(fname, word_dict):\n",
    "    with open(fname) as reader:\n",
    "        iter_list = []\n",
    "        for article in reader:\n",
    "            iter_list.append(generate_sent_article_pairs_from_single_article(article, word_dict))\n",
    "    return chain.from_iterable(iter_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build training data\n",
    "\n",
    "# Generate integer positive labeled data\n",
    "train_prefix = 'train250k'\n",
    "fname = \"wikipedia_{}.txt\".format(train_prefix)\n",
    "outfname = os.path.join(datadir, '{}_tokenized.jsonl'.format(train_prefix))\n",
    "counter = 0\n",
    "\n",
    "with jsonlines.open(outfname, 'w') as writer:\n",
    "    for sample in generate_sent_article_pairs_from_single_file(os.path.join(datadir, fname), w_dict):\n",
    "        writer.write(sample)\n",
    "        counter += 1\n",
    "        \n",
    "print(\"Finished generating {} data of size {}\".format(train_prefix, counter))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffle training data\n",
    "!shuf {outfname} > {train_prefix}_tokenized_shuf.jsonl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Function to generate dev/test data (with both positive and negative labels)\n",
    "\n",
    "def generate_pos_neg_samples_from_single_article(word_dict, article_idx, article_buffer, negative_sampling_rate=1):\n",
    "    sample_list = []\n",
    "    # generate a positive sample: a random sentence paired with the rest of its article\n",
    "    sent_list = readlines_from_article(article_buffer[article_idx])\n",
    "    idx = random.randint(0, len(sent_list) - 1)\n",
    "    wrapper_text_list = sent_list[:idx] + sent_list[idx + 1:]\n",
    "    wrapper_tokens = []\n",
    "    for sent1 in wrapper_text_list:\n",
    "        wrapper_tokens += sentence_to_integers(sent1, word_dict)\n",
    "    sent_tokens = sentence_to_integers(sent_list[idx], word_dict)\n",
    "    sample_list.append({'in0':sent_tokens, 'in1':wrapper_tokens, 'label':1})\n",
    "    # generate negative samples: the same sentence paired with other, randomly chosen articles\n",
    "    # (the article itself is always excluded, including when it is the last one in the buffer)\n",
    "    buff_len = len(article_buffer)\n",
    "    other_inds = list(range(article_idx)) + list(range(article_idx + 1, buff_len))\n",
    "    sampled_inds = np.random.choice(other_inds, size=negative_sampling_rate)\n",
    "    for n_idx in sampled_inds:\n",
    "        other_article = article_buffer[n_idx]\n",
    "        context_list = readlines_from_article(other_article)\n",
    "        context_tokens = []\n",
    "        for sent2 in context_list:\n",
    "            context_tokens += sentence_to_integers(sent2, word_dict)\n",
    "        sample_list.append({'in0': sent_tokens, 'in1':context_tokens, 'label':0})\n",
    "    return sample_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build dev and test data\n",
    "for data in ['dev10k', 'test10k']:\n",
    "    fname = os.path.join(datadir,'wikipedia_{}.txt'.format(data))\n",
    "    test_nsr = 5\n",
    "    outfname = '{}_tokenized-nsr{}.jsonl'.format(data, test_nsr)\n",
    "    article_buffer = list(get_article_iter_from_file(fname))\n",
    "    sample_buffer = []\n",
    "    for article_idx in range(len(article_buffer)):\n",
    "        sample_buffer += generate_pos_neg_samples_from_single_article(w_dict, article_idx, \n",
    "                                                                      article_buffer, \n",
    "                                                                      negative_sampling_rate=test_nsr)\n",
    "    with jsonlines.open(outfname, 'w') as writer:\n",
    "        writer.write_all(sample_buffer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Upload preprocessed data to S3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TRAIN_DATA=\"train250k_tokenized_shuf.jsonl\"\n",
    "DEV_DATA=\"dev10k_tokenized-nsr{}.jsonl\".format(test_nsr)\n",
    "TEST_DATA=\"test10k_tokenized-nsr{}.jsonl\".format(test_nsr)\n",
    "\n",
    "# NOTE: define your s3 bucket and key here\n",
    "S3_BUCKET = '<YOUR S3 BUCKET>'\n",
    "S3_KEY = 'object2vec-doc2vec'\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash -s \"$TRAIN_DATA\" \"$DEV_DATA\" \"$TEST_DATA\" \"$S3_BUCKET\" \"$S3_KEY\"\n",
    "\n",
    "aws s3 cp \"$1\" s3://$4/$5/input/train/\n",
    "aws s3 cp \"$2\" s3://$4/$5/input/validation/\n",
    "aws s3 cp \"$3\" s3://$4/$5/input/test/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define SageMaker session, Object2Vec image, S3 input and output paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker import get_execution_role\n",
    "from sagemaker.amazon.amazon_estimator import get_image_uri\n",
    "\n",
    "\n",
    "region = boto3.Session().region_name\n",
    "print(\"Your notebook is running on region '{}'\".format(region))\n",
    "\n",
    "sess = sagemaker.Session()\n",
    "\n",
    "role = get_execution_role()\n",
    "print(\"Your IAM role: '{}'\".format(role))\n",
    "\n",
    "container = get_image_uri(region, 'object2vec')\n",
    "print(\"The image uri used is '{}'\".format(container))\n",
    "\n",
    "print(\"Using S3 bucket: {} and key prefix: {}\".format(S3_BUCKET, S3_KEY))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## define input channels\n",
    "\n",
    "s3_input_path = os.path.join('s3://', S3_BUCKET, S3_KEY, 'input')\n",
    "\n",
    "s3_train = s3_input(os.path.join(s3_input_path, 'train', TRAIN_DATA), \n",
    "                    distribution='ShardedByS3Key', content_type='application/jsonlines')\n",
    "\n",
    "s3_valid = s3_input(os.path.join(s3_input_path, 'validation', DEV_DATA), \n",
    "                    distribution='ShardedByS3Key', content_type='application/jsonlines')\n",
    "\n",
    "s3_test = s3_input(os.path.join(s3_input_path, 'test', TEST_DATA), \n",
    "                   distribution='ShardedByS3Key', content_type='application/jsonlines')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## define output path\n",
    "output_path = os.path.join('s3://', S3_BUCKET, S3_KEY, 'models')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train and deploy doc2vec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We combine four new features into our training of Object2Vec. The first three improve learning performance:\n",
    "\n",
    "- Negative sampling: With the new `negative_sampling_rate` hyperparameter, users of Object2Vec only need to provide positively labeled data pairs. The algorithm automatically samples negative pairs internally during training.\n",
    "\n",
    "- Weight-sharing of the embedding layer: The new `tied_token_embedding_weight` hyperparameter lets users share the token embedding weights between the two encoders, which improves the algorithm's performance in this use case.\n",
    "\n",
    "- Comparator operator customization: The new `comparator_list` hyperparameter lets users mix and match different operators, so that they can tune the algorithm toward optimal performance for their applications."
   ]
  },
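  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following cell is a conceptual sketch of the `negative_sampling_rate` idea. It draws negatives from positive-only data in the jsonlines record format used in this notebook (`in0`, `in1`, `label`). Each `in0` is re-paired with the `in1` of a different, randomly chosen record. This scheme is an assumption for illustration only. Object2Vec performs its negative sampling internally, and its exact procedure may differ. The function name `add_negative_samples` is hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "def add_negative_samples(positives, negative_sampling_rate, seed=0):\n",
    "    # positives: records like {'in0': [...], 'in1': [...], 'label': 1}.\n",
    "    # For every positive record, emit negative_sampling_rate negatives that pair its in0\n",
    "    # with the in1 of a different, randomly chosen record (hypothetical scheme).\n",
    "    n = len(positives)\n",
    "    assert n >= 2, 'need at least two records to form negatives'\n",
    "    rng = random.Random(seed)\n",
    "    out = []\n",
    "    for i, rec in enumerate(positives):\n",
    "        out.append(rec)\n",
    "        for _ in range(negative_sampling_rate):\n",
    "            j = rng.randrange(n - 1)\n",
    "            if j >= i:\n",
    "                j += 1  # skip the record's own context\n",
    "            out.append({'in0': rec['in0'], 'in1': positives[j]['in1'], 'label': 0})\n",
    "    return out\n",
    "\n",
    "batch = [\n",
    "    {'in0': [5, 6], 'in1': [7, 8, 9], 'label': 1},\n",
    "    {'in0': [10], 'in1': [11, 12], 'label': 1},\n",
    "    {'in0': [13, 14], 'in1': [15], 'label': 1},\n",
    "]\n",
    "augmented = add_negative_samples(batch, negative_sampling_rate=3)\n",
    "print(len(augmented))  # 12: 3 positives + 3 * 3 negatives"
   ]
  },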
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Learning performance boost with new features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "_Table 1_ below shows the effect of these features on two metrics, test accuracy and test cross-entropy. Both are evaluated on a test set produced by the same data creation process (in this notebook, five negative pairs per positive pair).\n",
    "\n",
    "Without negative sampling, the training data contains only positive pairs, so the model effectively learns to label every pair as similar. On a test set with five negatives per positive, this yields an accuracy of about 1/6 (0.167). Test performance improves when negative sampling and weight-sharing of the embedding layer are turned on, and when we use a customized comparator operator (Hadamard product). When all these features are combined (last row of Table 1), the algorithm has the best performance as measured by both accuracy and cross-entropy.\n",
    "\n",
    "### Table 1\n",
    "\n",
    "|negative_sampling_rate|weight-sharing|comparator operator|Test accuracy|Test cross-entropy|\n",
    "| :------------------- | :----------: | :---------------: | :----------: | ----------------: |\n",
    "| off                  | off          | default           | 0.167        | 23                |\n",
    "| 3                    | off          | default           | 0.92         | 0.21              |\n",
    "| 5                    | off          | default           | 0.92         | 0.19              |\n",
    "| off                  | on           | default           | 0.167        | 23                |\n",
    "| 3                    | on           | default           | 0.93         | 0.18              |\n",
    "| 5                    | on           | default           | 0.936        | 0.17              |\n",
    "| off                  | on           | customized        | 0.17         | 23                |\n",
    "| 3                    | on           | customized        | 0.93         | 0.18              |\n",
    "| 5                    | on           | customized        | 0.94         | 0.17              |\n",
    "\n",
    "The fourth new feature is the `token_embedding_storage_type` hyperparameter, which enables sparse gradient updates by taking advantage of the sparse input format of Object2Vec. We measured the training speedup for different GPU and `max_seq_len` configurations; the results are summarized in Table 2 in the next section. Overall, we observe speedups of roughly 2x to 20x across machine and algorithm configurations."
   ]
  },
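  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following cell sketches what a comparator operator does. It combines the two encoder outputs into one vector that is fed to the classifier, concatenating one sub-vector per listed operator. The operator names mirror the values documented for `comparator_list` (`hadamard`, `concat`, `abs_diff`). The function `compare` itself is a hypothetical NumPy illustration, not Object2Vec's internal code. With `comparator_list` set to `hadamard`, as in the customized rows of Table 1, only the element-wise product is kept."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def compare(u, v, comparator_list='hadamard, concat, abs_diff'):\n",
    "    # Concatenate one sub-vector per operator, in the order the operators are listed.\n",
    "    ops = {\n",
    "        'hadamard': lambda a, b: a * b,                          # element-wise product\n",
    "        'concat': lambda a, b: np.concatenate([a, b], axis=-1),  # stack both encodings\n",
    "        'abs_diff': lambda a, b: np.abs(a - b),                  # element-wise |a - b|\n",
    "    }\n",
    "    names = [name.strip() for name in comparator_list.split(',')]\n",
    "    return np.concatenate([ops[name](u, v) for name in names], axis=-1)\n",
    "\n",
    "u = np.array([1.0, -2.0, 0.5])\n",
    "v = np.array([2.0, 1.0, 0.5])\n",
    "print(compare(u, v, 'hadamard'))  # element-wise product of the two encodings\n",
    "print(compare(u, v).shape)        # (12,): 3 (hadamard) + 6 (concat) + 3 (abs_diff)"
   ]
  },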
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training speedup with sparse gradient update"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "_Table 2_ below shows the training speedup with the sparse gradient update feature turned on. Speedups are reported as a function of the number of GPUs used for training and of the maximum sequence length (`max_seq_len`).\n",
    "\n",
    "### Table 2\n",
    "\n",
    "|num_gpus|Throughput with dense storage (samples/sec)|Throughput with sparse storage (samples/sec)|max_seq_len (in0/in1)|Speedup (x)|\n",
    "| :------------- | :----------: | :-----------:| :----------: | ----------: |\n",
    "|  1             | 5k           | 14k          | 50           |  2.8        |\n",
    "|  2             | 2.7k         | 23k          | 50           |  8.5        |\n",
    "|  3             | 2k           | 23~26k       | 50           |  10         |\n",
    "|  4             | 2k           | 23k          | 50           |  10         |\n",
    "|  8             | 1.1k         | 19k~20k      | 50           |  20         |\n",
    "|  1             | 1.1k         | 2k           | 500          |  2          |\n",
    "|  2             | 1.5k         | 3.6k         | 500          |  2.4        |\n",
    "|  4             | 1.6k         | 6k           | 500          |  3.75       |\n",
    "|  6             | 1.3k         | 6.7k         | 500          |  5.15       |\n",
    "|  8             | 1.1k        | 5.6k         | 500          |  5          |"
   ]
  },
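  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small NumPy sketch shows why sparse updates help. A mini-batch touches only a handful of rows of a large token-embedding table. A row-sparse update aggregates gradients per unique token id and writes only those rows. Under plain SGD this yields the same result as a dense update while doing far less work. The sizes and random gradients are made up for illustration, and this is not Object2Vec's implementation. With stateful optimizers such as Adam (used in this notebook), a sparse update is generally not exactly equivalent to a dense one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "vocab_size, dim, lr = 1000, 4, 0.1\n",
    "table = rng.normal(size=(vocab_size, dim))\n",
    "\n",
    "# a mini-batch touches only a few token ids; repeated ids accumulate gradient\n",
    "token_ids = np.array([3, 17, 3, 42])\n",
    "grad_per_token = rng.normal(size=(len(token_ids), dim))\n",
    "\n",
    "# dense update: build a gradient for every row, then update the whole table\n",
    "dense_grad = np.zeros_like(table)\n",
    "np.add.at(dense_grad, token_ids, grad_per_token)\n",
    "dense_table = table - lr * dense_grad\n",
    "\n",
    "# row-sparse update: sum gradients per unique id and update only those rows\n",
    "rows, inverse = np.unique(token_ids, return_inverse=True)\n",
    "row_grads = np.zeros((len(rows), dim))\n",
    "np.add.at(row_grads, inverse, grad_per_token)\n",
    "sparse_table = table.copy()\n",
    "sparse_table[rows] -= lr * row_grads\n",
    "\n",
    "print(np.allclose(dense_table, sparse_table))        # True under plain SGD\n",
    "print(f'rows touched: {len(rows)} of {vocab_size}')  # rows touched: 3 of 1000"
   ]
  },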
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define training hyperparameters\n",
    "\n",
    "hyperparameters = {\n",
    "      \"_kvstore\": \"device\",\n",
    "      \"_num_gpus\": 'auto',\n",
    "      \"_num_kv_servers\": \"auto\",\n",
    "      \"bucket_width\": 0,\n",
    "      \"dropout\": 0.4,\n",
    "      \"early_stopping_patience\": 2,\n",
    "      \"early_stopping_tolerance\": 0.01,\n",
    "      \"enc0_layers\": \"auto\",\n",
    "      \"enc0_max_seq_len\": 50,\n",
    "      \"enc0_network\": \"pooled_embedding\",\n",
    "      \"enc0_pretrained_embedding_file\": \"\",\n",
    "      \"enc0_token_embedding_dim\": 300,\n",
    "      \"enc0_vocab_size\": len(w_dict),  # number of entries in the vocabulary built above\n",
    "      \"enc1_network\": \"enc0\",\n",
    "      \"enc_dim\": 300,\n",
    "      \"epochs\": 20,\n",
    "      \"learning_rate\": 0.01,\n",
    "      \"mini_batch_size\": 512,\n",
    "      \"mlp_activation\": \"relu\",\n",
    "      \"mlp_dim\": 512,\n",
    "      \"mlp_layers\": 2,\n",
    "      \"num_classes\": 2,\n",
    "      \"optimizer\": \"adam\",\n",
    "      \"output_layer\": \"softmax\",\n",
    "      \"weight_decay\": 0\n",
    "}\n",
    "\n",
    "# enable the four new features\n",
    "hyperparameters['negative_sampling_rate'] = 3                   # negatives per positive pair\n",
    "hyperparameters['tied_token_embedding_weight'] = 'true'         # share token embeddings between encoders\n",
    "hyperparameters['comparator_list'] = 'hadamard'                 # customized comparator operator\n",
    "hyperparameters['token_embedding_storage_type'] = 'row_sparse'  # sparse gradient update\n",
    "\n",
    "# get estimator\n",
    "doc2vec = sagemaker.estimator.Estimator(container,\n",
    "                                        role,\n",
    "                                        train_instance_count=1,\n",
    "                                        train_instance_type='ml.p2.xlarge',\n",
    "                                        output_path=output_path,\n",
    "                                        sagemaker_session=sess)"
   ]
  },
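  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The hyperparameters above use `pooled_embedding` encoders with `tied_token_embedding_weight` set to true. The following cell is a rough illustration that assumes a pooled-embedding encoder averages the token embeddings of its input. It encodes a sentence and a context with one shared embedding table, which is what tying the weights means here. The table size and token ids are hypothetical. The actual Object2Vec encoders may differ in details such as padding handling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "vocab_size, dim = 50, 8\n",
    "# one token-embedding table shared by both encoders (tied weights)\n",
    "shared_table = rng.normal(size=(vocab_size, dim))\n",
    "\n",
    "def pooled_embedding(token_ids, table):\n",
    "    # average the embeddings of the tokens in a (non-empty) input sequence\n",
    "    return table[np.asarray(token_ids)].mean(axis=0)\n",
    "\n",
    "sentence = [4, 9, 9, 2]          # in0: token ids of a sentence\n",
    "context = [1, 4, 7, 9, 11, 2]    # in1: token ids of the rest of the article\n",
    "enc0 = pooled_embedding(sentence, shared_table)\n",
    "enc1 = pooled_embedding(context, shared_table)\n",
    "print(enc0.shape, enc1.shape)    # (8,) (8,)"
   ]
  },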
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set hyperparameters\n",
    "doc2vec.set_hyperparameters(**hyperparameters)\n",
    "\n",
    "# fit estimator with data\n",
    "doc2vec.fit({'train': s3_train, 'validation':s3_valid, 'test':s3_test})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# deploy model\n",
    "\n",
    "doc2vec_model = doc2vec.create_model(\n",
    "                        serializer=json_serializer,\n",
    "                        deserializer=json_deserializer,\n",
    "                        content_type='application/json')\n",
    "\n",
    "predictor = doc2vec_model.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Apply learned embeddings to document retrieval task\n",
    "\n",
    "After training the model, we can use the encoders in Object2Vec to map new articles and sentences into a shared embedding space. Then we evaluate the quality of these embeddings with a downstream document retrieval task.\n",
    "\n",
    "In the retrieval task, given a query sentence, the trained model must find its best-matching document (the ground-truth document, i.e., the one that contains the sentence) in a pool that also contains 10,000 other, non-ground-truth documents."
   ]
  },
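  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the retrieval task concrete, the next cell is a minimal toy sketch of how a single query is scored: the query sentence and every candidate document are compared by cosine similarity, and the rank of the ground-truth document is 1 plus the number of other documents that score strictly higher. The helper `rank_of_ground_truth` and the 2-dimensional embeddings are hypothetical and only for illustration. The evaluation cells further below perform the same kind of scoring in matrix form on the real Object2Vec embeddings, where ties are broken by `np.argsort` rather than by counting strictly higher scores."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def rank_of_ground_truth(query_emb, gt_doc_emb, pool_doc_embs):\n",
    "    # Hypothetical helper: 1-based rank of the ground-truth document among itself\n",
    "    # plus the pool, scoring every document by cosine similarity with the query.\n",
    "    def unit(x):\n",
    "        return x / np.linalg.norm(x, axis=-1, keepdims=True)\n",
    "\n",
    "    q = unit(np.asarray(query_emb, dtype=float))\n",
    "    docs = unit(np.vstack([gt_doc_emb, pool_doc_embs]))  # row 0 is the ground truth\n",
    "    scores = docs @ q\n",
    "    # rank 1 means no pool document scores strictly higher than the ground truth\n",
    "    return 1 + int(np.sum(scores[1:] > scores[0]))\n",
    "\n",
    "\n",
    "# toy 2-d embeddings (made-up numbers, not model output)\n",
    "query = [1.0, 0.0]\n",
    "gt_doc = [0.9, 0.1]\n",
    "pool = [[0.0, 1.0], [1.0, 0.05], [-1.0, 0.0]]\n",
    "print(rank_of_ground_truth(query, gt_doc, pool))  # prints 2"
   ]
  },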
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
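    "def generate_tokenized_articles_from_single_file(fname, word_dict):\n",
    "    # yield each article in fname as a single flat list of token IDs (its sentences concatenated)\n",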
    "    for article in get_article_iter_from_file(fname):\n",
    "        integer_article = []\n",
    "        for sent in readlines_from_article(article):\n",
    "            integer_article += sentence_to_integers(sent, word_dict)\n",
    "        yield integer_article"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import os\n",
    "\n",
    "import jsonlines\n",
    "\n",
    "\n",
    "def read_jsonline(fname):\n",
    "    \"\"\"\n",
    "    Yield the records of a JSON Lines file one at a time.\n",
    "    \"\"\"\n",
    "    with jsonlines.open(fname) as reader:\n",
    "        for line in reader:\n",
    "            yield line\n",
    "\n",
    "def send_payload(predictor, payload):\n",
    "    return predictor.predict(payload)\n",
    "\n",
    "def write_to_jsonlines(data, fname):\n",
    "    # append the 'predictions' list of an endpoint response to a JSON Lines file\n",
    "    with jsonlines.open(fname, 'a') as writer:\n",
    "        data = data['predictions']\n",
    "        writer.write_all(data)\n",
    "\n",
    "\n",
    "def eval_and_write(predictor, fname, to_fname, batch_size):\n",
    "    # embed every record in fname in batches and append the predictions to to_fname\n",
    "    if os.path.exists(to_fname):\n",
    "        print(\"Removing existing embedding file {}\".format(to_fname))\n",
    "        os.remove(to_fname)\n",
    "    print(\"Computing embeddings of the data in {} and storing them in {}...\".format(fname, to_fname))\n",
    "    test_data_content = list(read_jsonline(fname))\n",
    "    n_test = len(test_data_content)\n",
    "    n_batches = math.ceil(n_test / float(batch_size))\n",
    "    for idx in range(n_batches):\n",
    "        if idx % 10 == 0:\n",
    "            print(\"Running inference on batch {}/{}\".format(idx + 1, n_batches))\n",
    "        start = idx * batch_size\n",
    "        end = min(start + batch_size, n_test)\n",
    "        payload = {'instances': test_data_content[start:end]}\n",
    "        data = send_payload(predictor, payload)\n",
    "        write_to_jsonlines(data, to_fname)\n",
    "\n",
    "def get_embeddings(predictor, test_data_content, batch_size):\n",
    "    # embed a list of records in batches and return the list of predictions\n",
    "    n_test = len(test_data_content)\n",
    "    n_batches = math.ceil(n_test / float(batch_size))\n",
    "    embeddings = []\n",
    "    for idx in range(n_batches):\n",
    "        if idx % 10 == 0:\n",
    "            print(\"Running inference on batch {}/{}\".format(idx + 1, n_batches))\n",
    "        start = idx * batch_size\n",
    "        end = min(start + batch_size, n_test)\n",
    "        payload = {'instances': test_data_content[start:end]}\n",
    "        data = send_payload(predictor, payload)\n",
    "        embeddings += data['predictions']\n",
    "    return embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "basedocs_fpath = os.path.join(datadir, 'wikipedia_test_basedocs.txt')\n",
    "test_fpath = '{}_tokenized-nsr{}.jsonl'.format('test10k', test_nsr)\n",
    "eval_basedocs = 'test_basedocs_tokenized_in0.jsonl'\n",
    "basedocs_emb = 'test_basedocs_embeddings.jsonl'\n",
    "sent_doc_emb = 'test10k_embeddings_pairs.jsonl'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jsonlines\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import normalize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 100\n",
    "\n",
    "# tokenize basedocs\n",
    "with jsonlines.open(eval_basedocs, 'w') as writer:\n",
    "    for data in generate_tokenized_articles_from_single_file(basedocs_fpath, w_dict):\n",
    "        writer.write({'in0': data})\n",
    "\n",
    "# get basedocs embedding\n",
    "eval_and_write(predictor, eval_basedocs, basedocs_emb, batch_size)\n",
    "\n",
    "# get embeddings for the positive (label 1) pairs: each sentence and its ground-truth article\n",
    "sentences = []\n",
    "gt_articles = []\n",
    "for data in read_jsonline(test_fpath):\n",
    "    if data['label'] == 1:\n",
    "        sentences.append({'in0': data['in0']})\n",
    "        gt_articles.append({'in0': data['in1']})\n",
    "\n",
    "sent_emb = get_embeddings(predictor, sentences, batch_size)\n",
    "doc_emb = get_embeddings(predictor, gt_articles, batch_size)\n",
    "\n",
    "with jsonlines.open(sent_doc_emb, 'w') as writer:\n",
    "    for (sent, doc) in zip(sent_emb, doc_emb):\n",
    "        writer.write({'sent': sent['embeddings'], 'doc': doc['embeddings']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "del w_dict\n",
    "del sent_emb, doc_emb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cells below evaluate the performance of the Object2Vec model on the document retrieval task.\n",
    "\n",
    "We use two metrics, hits@k and mean rank, to evaluate retrieval performance. Note that the query sentence has been removed from its ground-truth document in the pool; otherwise the task would be trivial.\n",
    "\n",
    "* hits@k: the fraction of queries whose best-matching (ground-truth) document is among the top k documents retrieved by the algorithm.\n",
    "* mean rank: the average rank of the best-matching document, as determined by the algorithm, over all queries."
   ]
  },
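  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both metrics can be read off the per-query ranks of the ground-truth documents. The toy cell below is a minimal sketch with made-up ranks for five queries; the helper `retrieval_metrics` is hypothetical and not part of the evaluation code that follows, which derives hits@k from stored top-k index lists instead. A query counts as a hit at k when the rank of its ground-truth document is at most k."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def retrieval_metrics(ranks, k_list):\n",
    "    # Hypothetical helper. ranks: 1-based rank of the ground-truth document for each query.\n",
    "    ranks = np.asarray(ranks)\n",
    "    # hits@k: fraction of queries whose ground-truth document is in the top k\n",
    "    hits = {k: float(np.mean(ranks <= k)) for k in k_list}\n",
    "    # mean rank: average rank over all queries (lower is better)\n",
    "    return hits, float(np.mean(ranks))\n",
    "\n",
    "\n",
    "# made-up ranks for five queries, for illustration only\n",
    "hits, mean_rank = retrieval_metrics([1, 3, 12, 1, 60], k_list=[1, 10, 50])\n",
    "print(hits)       # prints {1: 0.4, 10: 0.6, 50: 0.8}\n",
    "print(mean_rank)  # prints 15.4"
   ]
  },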
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Construct L2-normalized embedding matrices for the base documents, query sentences, and ground-truth documents\n",
    "# (with unit-norm vectors, dot products equal cosine similarities)\n",
    "\n",
    "basedocs = []\n",
    "with jsonlines.open(basedocs_emb) as reader:\n",
    "    for line in reader:\n",
    "        basedocs.append(np.array(line['embeddings']))\n",
    "\n",
    "sent_embs = []\n",
    "gt_doc_embs = []\n",
    "with jsonlines.open(sent_doc_emb) as reader2:\n",
    "    for line2 in reader2:\n",
    "        sent_embs.append(line2['sent'])\n",
    "        gt_doc_embs.append(line2['doc'])\n",
    "\n",
    "# shape (emb_dim, num_basedocs); each column has unit norm\n",
    "basedocs_emb_mat = normalize(np.array(basedocs).T, axis=0)\n",
    "# shape (num_queries, emb_dim); each row has unit norm\n",
    "sent_emb_mat = normalize(np.array(sent_embs), axis=1)\n",
    "# shape (emb_dim, num_queries); column i is the ground-truth document of query i\n",
    "gt_emb_mat = normalize(np.array(gt_doc_embs).T, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
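    "def get_chunk_query_rank(sent_emb_mat, basedocs_emb_mat, gt_emb_mat, largest_k):\n",
    "    # For each query in the chunk, return the 1-based rank of its ground-truth document and the\n",
    "    # indices of the largest_k highest-scoring candidates in increasing order of score\n",
    "    # (index 0 is the ground-truth document; index j > 0 is base document j - 1).\n",
    "\n",
    "    # (chunk_size, num_basedocs) score matrix; this step is memory-consuming if the chunk is large\n",
    "    dot_with_basedocs = np.matmul(sent_emb_mat, basedocs_emb_mat)\n",
    "    # score of each query against its own ground-truth document, i.e. the diagonal of\n",
    "    # sent_emb_mat @ gt_emb_mat, computed without building the full chunk x chunk product\n",
    "    dot_with_gt = np.einsum('ij,ji->i', sent_emb_mat, gt_emb_mat)\n",
    "    # prepend the ground-truth scores as column 0\n",
    "    final_ranking_scores = np.insert(dot_with_basedocs, 0, dot_with_gt, axis=1)\n",
    "    query_rankings = list()\n",
    "    largest_k_list = list()\n",
    "    for row in final_ranking_scores:\n",
    "        ranking_ind = np.argsort(row)  # indices that sort the row in increasing order of similarity\n",
    "        num_scores = len(ranking_ind)\n",
    "        # rank 1 means the ground-truth document has the highest score\n",
    "        query_rankings.append(num_scores - list(ranking_ind).index(0))\n",
    "        largest_k_list.append(np.array(ranking_ind[-largest_k:]).astype(int))\n",
    "    return query_rankings, largest_k_list"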
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note:** We evaluate the learned embeddings on chunks of test sentence-document pairs to save run-time memory, so that the code works on the smallest notebook instance, *ml.t2.medium*. If you have a larger notebook instance, you can increase `chunk_size` to speed up evaluation; for instances larger than *ml.t2.xlarge*, you can set `chunk_size = num_test_samples`."
   ]
  },
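  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The memory cost that motivates chunking can be estimated up front: for each chunk, `get_chunk_query_rank` builds a float64 score matrix with one row per query in the chunk and one column per candidate document (the base documents plus the ground-truth column). The cell below is a rough back-of-the-envelope sketch with a hypothetical helper, `score_matrix_megabytes`; it assumes 8-byte scores and the pool of 10,000 base documents described above, and it ignores temporary copies, so actual peak memory use is higher."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def score_matrix_megabytes(chunk_size, num_basedocs, bytes_per_value=8):\n",
    "    # Hypothetical helper: approximate size in MB of the (chunk_size, num_basedocs + 1)\n",
    "    # float64 score matrix built for one chunk. Temporary copies made by np.insert and\n",
    "    # np.argsort are not counted, so the real peak memory use is higher.\n",
    "    return chunk_size * (num_basedocs + 1) * bytes_per_value / 1e6\n",
    "\n",
    "\n",
    "# assuming a pool of 10,000 base documents\n",
    "print(score_matrix_megabytes(1000, 10000))   # prints 80.008\n",
    "print(score_matrix_megabytes(10000, 10000))  # prints 800.08"
   ]
  },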
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chunk_size = 1000\n",
    "num_test_samples = len(sent_embs)\n",
    "# number of chunks, rounded up so that a smaller last chunk is still evaluated\n",
    "num_chunks = (num_test_samples + chunk_size - 1) // chunk_size\n",
    "k_list = [1, 5, 10, 20, 50]\n",
    "largest_k = max(k_list)\n",
    "query_all_rankings = list()\n",
    "all_largest_k_list = list()\n",
    "\n",
    "for chunk_idx, i in enumerate(range(0, num_test_samples, chunk_size)):\n",
    "    print(\"Evaluating chunk {}/{}\".format(chunk_idx + 1, num_chunks))\n",
    "    j = i + chunk_size\n",
    "    sent_emb_submat = sent_emb_mat[i:j, :]\n",
    "    gt_emb_submat = gt_emb_mat[:, i:j]\n",
    "    query_rankings, largest_k_list = get_chunk_query_rank(sent_emb_submat, basedocs_emb_mat, gt_emb_submat, largest_k)\n",
    "    query_all_rankings += query_rankings\n",
    "    all_largest_k_list.append(np.array(largest_k_list).astype(int))\n",
    "\n",
    "# shape (num_test_samples, largest_k); in each row, the last column is the top-ranked document\n",
    "all_largest_k_mat = np.concatenate(all_largest_k_list, axis=0).astype(int)\n",
    "\n",
    "print(\"Summary:\")\n",
    "print(\"Mean query rank: {}\".format(np.mean(query_all_rankings)))\n",
    "print(\"Percentiles of query ranks: 50%: {}, 80%: {}, 90%: {}, 99%: {}\".format(*np.percentile(query_all_rankings, [50, 80, 90, 99])))\n",
    "\n",
    "for k in k_list:\n",
    "    top_k_mat = all_largest_k_mat[:, -k:]\n",
    "    # a query is a hit if its ground-truth document (index 0) is among its top k candidates\n",
    "    num_hits = int(np.sum(np.any(top_k_mat == 0, axis=1)))\n",
    "    print(\"hits@{}: {}/{}\".format(k, num_hits, len(top_k_mat)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Comparison with the StarSpace algorithm\n",
    "\n",
    "We compare the performance of Object2Vec with the [StarSpace](https://github.com/facebookresearch/StarSpace) algorithm on the document retrieval evaluation task, using a set of 250,000 Wikipedia documents. The experimental results in the table below show that Object2Vec outperforms StarSpace on every reported metric (higher hits@k and lower mean rank are better), even though both models use the same kind of encoder for sentences and documents.\n",
    "\n",
    "| Algorithm      | hits@1       | hits@10      | hits@20      |  mean rank  |\n",
    "| :------------- | :----------: | :-----------:| :----------: | ----------: |\n",
    "|  StarSpace     | 21.98%       | 42.77%       | 50.55%       |  303.34     |\n",
    "|  Object2Vec    | 26.40%       | 47.42%       | 53.83%       |  248.67     |\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# delete the endpoint to stop incurring charges\n",
    "predictor.delete_endpoint()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
